function [ Y ] = diag_nd( X )
%DIAG_ND 取多维矩阵前2维的对角线
%
% Input Arguments:
% X: 多维矩阵，前两维尺寸相同
%
% Output Arguments:
% Y: 维数比X少一的矩阵，第一位尺寸与X前两维尺寸相同，后续维数尺寸与X第三维及后续维数尺寸相同，第一维为X前两维的对角线

sz = size(X);
if length(sz) < 3
    Y = diag(X);
else
    preSecDimNumEl = sz(1, 1) * sz(1, 2);
    postSecDimNumEl = prod(sz(1, 3:end));
    rowSub = 1:sz(1, 1);
    linearInd = zeros(1, sz(1, 1)*postSecDimNumEl);
    baseLinearInd = sub2ind(sz(1, 1:2), rowSub, rowSub);
    for i=1:postSecDimNumEl
        linearInd(1, (i-1)*sz(1, 1)+(1:sz(1, 1))) = (i-1)*preSecDimNumEl + baseLinearInd;
    end
    Y = reshape(X(linearInd), sz(1, [1 3:end])); % NOTE: 使用线性索引
end
end